python, decision treeのtree graph plot¶
動機¶
- decision treeのtree plotのためにいろいろ調べた
- kerasで使うので久しぶりに調べた(2017-02-16)
- 平面で境界線のplotはmatplotlibで書くので、この記事を閉じてよい
scikit-learnでdecision tree¶
- http://scikit-learn.org/stable/modules/tree.html
- http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html
plot用APIのWIP¶
https://github.com/scikit-learn/scikit-learn/pull/6380
library¶
pydot¶
- https://pypi.python.org/pypi/pydot
- https://pypi.python.org/pypi/pydot2/1.0.32
- https://pypi.python.org/pypi/pydot3
- https://pypi.python.org/pypi/pydot-ng
pydot2¶
上記でいろいろあげたけど、pydotを使うならこれでよさそう- pydotと同じ作者
- source codeは公開されているけどrepositoryがない?
- sklearnのexampleにも使われているが、開発継続されていない?
- python3でやってみたけど、エラー
from sklearn.externals.six import StringIO as SkStringIO
from IPython.display import Image
import pydot
dot_data = SkStringIO()
export_graphviz(
tree_clf, out_file=dot_data,
feature_names=X.columns,
class_names=["0", "1"],
filled=True, rounded=True,
special_characters=True)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
pydot3¶
- kerasでvisualizationを試したときに使った
- python3で動いた
- forkではなくportしてcompatibleにしたcommitをしている
- いろいろあるrepositoryの中で一番直近に更新されている
pydot-ng¶
- pydotのrepositoryのhistoryを引き継いでいる個人repositoryをforkしたもの
- organizationになっている
import pydot_ng as pydot
としてmodule名の代用をしないといけないので既存moduleではだめっぽそう- https://github.com/fchollet/keras/commit/e5d3abdf09d8c281ca8817b6292a044673ba3007
- kerasではこのmoduleを優先して(pydotとして)importしているのでこれでよい
- なければpydotがimportされるのでpydot3でも大丈夫だった
networkx¶
- http://pypi.python.org/pypi/networkx/
- GraphVizは必須ではなくOption
PyGraphviz¶
- http://pygraphviz.github.io
- networkxと連携できる
graphviz¶
- https://pypi.python.org/pypi/graphviz
- plot用APIのWIPではこれが使われていた
- こちらは問題なく動いた
pip install graphviz
graphviz_tree_plot.py¶
from sklearn.externals.six import StringIO as SkStringIO
from IPython.display import Image
import graphviz
def tree_plot(decision_tree, width=500, height=500, max_depth=None,
feature_names=None, class_names=None, label='all',
filled=False, leaves_parallel=False, impurity=True,
node_ids=False, proportion=False, rotate=False,
rounded=False, special_characters=False):
in_memory_dot_file = SkStringIO()
export_graphviz(
decision_tree, out_file=in_memory_dot_file, max_depth=max_depth,
feature_names=feature_names, class_names=class_names, label=label,
filled=filled, leaves_parallel=leaves_parallel, impurity=impurity,
node_ids=node_ids, proportion=proportion, rotate=rotate,
rounded=rounded, special_characters=special_characters)
src = graphviz.Source(in_memory_dot_file.getvalue())
return Image(src.pipe(format='png'), height=height, width=width)
tree_plot(tree_clf,
feature_names=X.columns,
class_names=["d", "s"],
# class_names={1: "s", 0: "d"},
filled=True,
rounded=True,
special_characters=True
)
比較¶
http://plaza.rakuten.co.jp/kugutsushi/diary/200711080000/